import jax.nn as nn
import jax
import jax.numpy as jnp

from functools import partial
from flax import linen as nn
from src.models.transformers import *
from src.models.rlt.rlt import *
from src.models.rlt.rlt_uoro import *
from src.models.rnns.rnn import TruncatedVanillaRNN,LSTM 
from src.utils import *
from typing import NamedTuple, Optional,Any,Sequence

# class LinearTransformerPredictor(nn.Module):
#     d_model:int
#     n_heads:int
#     d_ffc:int
#     n_layers:int
#     output_size:int
#     truncation:int
#     kernel_phi:Any


#     @nn.compact
#     def __call__(self,inputs):
#         input_dim=inputs.shape[-1]
#         inputs_concat=self.variable('state','inputs_concat',jnp.zeros,(self.truncation,input_dim))
#         inputs_concat.value=jnp.concatenate([inputs_concat.value[1:],inputs.reshape(1,-1)],axis=0)
#         trf_model=LinearTransformer(d_model=self.d_model,d_ffc=self.d_ffc,n_heads=self.n_heads,
#                                                     truncation=self.truncation,n_layers=self.n_layers,kernel_phi=self.kernel_phi)
#         trf_out=trf_model(inputs_concat.value)
#         pred=nn.Sequential([nn.Dense(self.output_size)])(trf_out[-1])
#         return pred


# class UniversalTransformerPredictor(nn.Module):
#     d_model:int
#     n_heads:int
#     d_ffc:int
#     n_layers:int
#     output_size:int
#     truncation:int
#     kernel_phi:Any

#     @nn.compact
#     def __call__(self,inputs):
#         input_dim=inputs.shape[-1]
#         inputs_concat=self.variable('state','inputs_concat',jnp.zeros,(self.truncation,input_dim))
#         inputs_concat.value=jnp.concatenate([inputs_concat.value[1:],inputs.reshape(1,-1)],axis=0)
#         trf_model=UniversalLinearTransformer(d_model=self.d_model,d_ffc=self.d_ffc,n_heads=self.n_heads,
#                                                     truncation=self.truncation,n_layers=self.n_layers,kernel_phi=self.kernel_phi)
#         trf_out=trf_model(inputs_concat.value)
#         pred=nn.Sequential([nn.Dense(self.output_size)])(trf_out[-1])
#         return pred



# class OULTTBBPTPredictor(nn.Module):
#     n_layers:int
#     d_model:int
#     d_ffc:int
#     n_heads:int 
#     kernel_dim:int
#     truncation:int
#     output_size:int
#     kernel_phi:Any
#     pos_emb_type:str='rotary'
#     use_layer_emb:str=True
#     update_rule:str='gated'


#     @nn.compact
#     def __call__(self,inputs):
#         """
#             Online Linear Transformer with full context trained using truncated TBPTT
#         Args:
#             inputs (_type_): shape (input_dim)
#         Returns:
#             jax.numpy.array: shape (1,)
#         """
#         input_dim=inputs.shape[-1]
#         inputs_concat=self.variable('state','inputs_concat',jnp.full,(self.truncation,input_dim),0.1)
#         inputs_concat.value=jnp.concatenate([inputs_concat.value[1:],inputs.reshape(1,-1)],axis=0)
#         model=RecurrentUniversalLinearTransformer(n_layers=self.n_layers,d_model=self.d_model,d_ffc=self.d_ffc,
#                                         n_heads=self.n_heads,kernel_dim=self.kernel_dim,kernel_phi=self.kernel_phi,
#                                         update_rule=self.update_rule,pos_emb_type=self.pos_emb_type,use_layer_emb=self.use_layer_emb)
#         memory_state=self.variable('state','memory',model.initialize_memory)
#         trf_out,new_memory=model(inputs_concat.value,memory_state.value)
#         if not self.is_initializing():
#             memory_state.value=tree_index(new_memory,0)
#         pred=nn.Sequential([nn.Dense(self.output_size)])(trf_out[-1])
#         return pred


class RLTTBBPTPredictor(nn.Module):
    """Recurrent linear-transformer predictor trained with truncated BPTT.

    Each call appends the new input to a sliding window holding the last
    ``truncation`` inputs (kept in the ``'state'`` collection). It runs a
    ``RecurrentLinearTransformer`` over that window, starting from the stored
    memory, and projects the output at the newest window position with a
    Dense layer.

    The module writes to the ``'state'`` collection, so it has to be applied
    with that collection mutable.
    """

    n_layers: int
    d_model: int
    d_ffc: int  # forwarded to RecurrentLinearTransformer
    n_heads: int
    kernel_dim: int
    truncation: int  # number of past inputs kept in the sliding window
    output_size: int  # width of the final Dense readout
    kernel_phi: Any  # forwarded to RecurrentLinearTransformer
    pos_emb_type: str
    use_layer_emb: str  # NOTE(review): annotated str but looks like a flag -- confirm
    no_memory: bool  # if True, the stored memory is never advanced between calls
    update_rule: str = 'gated'

    @nn.compact
    def __call__(self, inputs):
        """Consume one input step and return the prediction for it.

        Args:
            inputs: features for a single step. They are flattened into one
                window row of width ``inputs.shape[-1]``, so in practice this
                should be a vector of shape ``(input_dim,)``.

        Returns:
            ``Dense(output_size)`` applied to the transformer output at the
            last (newest) window position.
        """
        width = inputs.shape[-1]
        window = self.variable('state', 'inputs_concat', jnp.zeros, (self.truncation, width))
        # Slide the window: drop the oldest row and append the new input last.
        window.value = jnp.vstack((window.value[1:], inputs.reshape(1, -1)))

        core = RecurrentLinearTransformer(
            n_layers=self.n_layers,
            d_model=self.d_model,
            d_ffc=self.d_ffc,
            n_heads=self.n_heads,
            kernel_dim=self.kernel_dim,
            kernel_phi=self.kernel_phi,
            update_rule=self.update_rule,
            pos_emb_type=self.pos_emb_type,
            use_layer_emb=self.use_layer_emb,
        )
        memory_init_args = (
            self.n_layers,
            self.n_heads,
            self.d_model,
            self.kernel_dim,
            self.pos_emb_type,
            self.update_rule,
        )
        memory = self.variable('state', 'memory', core.initialize_memory, *memory_init_args)

        outputs, next_memory = core(window.value, memory.value)

        # The stored memory stays untouched during init and when no_memory is set.
        # NOTE: the input window above IS updated during init. Otherwise we keep
        # entry 0 of the returned memory tree -- presumably the state after the
        # oldest window row, so the next (shifted) window resumes from it;
        # TODO confirm against RecurrentLinearTransformer / tree_index.
        if not (self.is_initializing() or self.no_memory):
            memory.value = tree_index(next_memory, 0)

        readout = nn.Sequential([nn.Dense(self.output_size)])
        return readout(outputs[-1])


class RLTTUOROPredictor(nn.Module):
    """Recurrent linear-transformer predictor using the UORO model variant.

    Works like the truncated-BPTT predictor: a sliding window of the last
    ``truncation`` inputs is kept in the ``'state'`` collection and passed
    through a ``RecurrentLinearTransformerUORO``. The stored ``'memory'``
    variable here holds a ``(memory, memory_grads)`` pair, which is fed to the
    model and refreshed from its outputs. (``memory_grads`` is presumably the
    UORO gradient estimate -- confirm against RecurrentLinearTransformerUORO.)

    The module writes to the ``'state'`` collection, so it has to be applied
    with that collection mutable.
    """

    n_layers: int
    d_model: int
    d_ffc: int  # forwarded to RecurrentLinearTransformerUORO
    n_heads: int
    kernel_dim: int
    truncation: int  # number of past inputs kept in the sliding window
    output_size: int  # width of the final Dense readout
    kernel_phi: Any  # forwarded to the model and to its memory initializer
    pos_emb_type: str
    use_layer_emb: str  # NOTE(review): annotated str but looks like a flag -- confirm
    no_memory: bool  # if True, the stored (memory, grads) pair is never advanced
    update_rule: str = 'gated'
    # Passed to the model as ret_mem_grad_ax and used to pick which entry of
    # the returned memory tree is stored for the next call.
    mem_ax: int = 0

    @nn.compact
    def __call__(self, inputs):
        """Consume one input step and return the prediction for it.

        Args:
            inputs: features for a single step. They are flattened into one
                window row of width ``inputs.shape[-1]``, so in practice this
                should be a vector of shape ``(input_dim,)``.

        Returns:
            ``Dense(output_size)`` applied to the model output at the last
            (newest) window position.
        """
        width = inputs.shape[-1]
        window = self.variable('state', 'inputs_concat', jnp.zeros, (self.truncation, width))
        # Slide the window: drop the oldest row and append the new input last.
        window.value = jnp.vstack((window.value[1:], inputs.reshape(1, -1)))

        core = RecurrentLinearTransformerUORO(
            n_layers=self.n_layers,
            d_model=self.d_model,
            d_ffc=self.d_ffc,
            n_heads=self.n_heads,
            kernel_dim=self.kernel_dim,
            kernel_phi=self.kernel_phi,
            update_rule=self.update_rule,
            pos_emb_type=self.pos_emb_type,
            use_layer_emb=self.use_layer_emb,
            ret_mem_grad_ax=self.mem_ax,
        )
        memory_init_args = (
            self.n_layers,
            self.n_heads,
            self.d_model,
            self.kernel_dim,
            self.pos_emb_type,
            self.update_rule,
            self.kernel_phi,
        )
        carry = self.variable(
            'state', 'memory', RecurrentLinearTransformerUORO.initialize_memory, *memory_init_args
        )

        mem, mem_grads = carry.value
        outputs, next_mem, next_mem_grads = core(window.value, mem, mem_grads)

        # The stored pair stays untouched during init and when no_memory is set.
        # NOTE: the input window above IS updated during init.
        if not (self.is_initializing() or self.no_memory):
            carry.value = (tree_index(next_mem, self.mem_ax), next_mem_grads)

        readout = nn.Sequential([nn.Dense(self.output_size)])
        return readout(outputs[-1])
    

    


class VanillaRNNPredictor(nn.Module):
    """Predictor built from a ``TruncatedVanillaRNN`` followed by a Dense readout.

    Inputs go straight to the RNN. Any windowing or recurrent-state handling is
    presumably done inside ``TruncatedVanillaRNN`` (not visible here).
    """

    d_model: int  # hidden width, forwarded to TruncatedVanillaRNN
    output_size: int  # width of the final Dense readout
    truncation: int  # forwarded to TruncatedVanillaRNN

    @nn.compact
    def __call__(self, inputs):
        """Run the RNN on ``inputs`` and project its output to ``output_size``."""
        rnn = TruncatedVanillaRNN(truncation=self.truncation, d_model=self.d_model)
        hidden = rnn(inputs)
        readout = nn.Sequential([nn.Dense(self.output_size)])
        return readout(hidden)

class LSTMTBBPTPredictor(nn.Module):
    """Stacked-LSTM predictor trained with truncated BPTT over a sliding window.

    Each call appends the new input to a window of the last ``truncation``
    inputs (kept in the ``'state'`` collection). It runs ``n_layers`` stacked
    ``LSTM(d_model)`` layers over that window and projects the top layer's
    output at the newest position with a Dense layer. Layer ``i`` keeps its
    recurrent state in the ``'state'`` variable ``'memory{i}'``.

    The module writes to the ``'state'`` collection, so it has to be applied
    with that collection mutable.
    """

    d_model: int  # hidden width of every LSTM layer
    output_size: int  # width of the final Dense readout
    truncation: int  # number of past inputs kept in the sliding window
    n_layers: int  # number of stacked LSTM layers; must be >= 1

    @nn.compact
    def __call__(self, inputs):
        """Consume one input step and return the prediction for it.

        Args:
            inputs: features for a single step. They are flattened into one
                window row of width ``inputs.shape[-1]``, so in practice this
                should be a vector of shape ``(input_dim,)``.

        Returns:
            ``Dense(output_size)`` applied to the top LSTM layer's output at the
            last (newest) window position.

        Raises:
            ValueError: if ``n_layers < 1``. Previously this surfaced as an
                unbound-local error at the readout.
        """
        if self.n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {self.n_layers}")

        input_dim = inputs.shape[-1]
        window = self.variable('state', 'inputs_concat', jnp.zeros, (self.truncation, input_dim))
        # Slide the window: drop the oldest row and append the new input last.
        window.value = jnp.concatenate([window.value[1:], inputs.reshape(1, -1)], axis=0)

        # Layer 0 consumes the input window; every later layer consumes the
        # output sequence of the layer below it.
        hidden = window.value
        for i in range(self.n_layers):
            lstm = LSTM(self.d_model)
            memory_state = self.variable('state', f'memory{i}', lstm.initialize_state)
            hidden, new_memory = lstm(hidden, memory_state.value)
            # Memory is not advanced during init. Afterwards keep entry 0 of the
            # returned memory tree -- presumably the state after the oldest
            # window row, so the next (shifted) window resumes from it;
            # TODO confirm against LSTM / tree_index.
            if not self.is_initializing():
                memory_state.value = tree_index(new_memory, 0)

        pred = nn.Sequential([nn.Dense(self.output_size)])(hidden[-1])
        return pred